--[[ Copyright (C) 2018 Google Inc.

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
]]

local game = require 'dmlab.system.game'
local tensor = require 'dmlab.system.tensor'
local random = require 'common.random'
local pickups = require 'common.pickups'
local custom_observations = require 'decorators.custom_observations'
local timeout = require 'decorators.timeout'
local map_maker = require 'dmlab.system.map_maker'
local randomMap = random(map_maker:randomGen())

-- Number of reward objects placed in each room (apples in room 1, cakes in
-- room 2).
local OBJECT_NUM = 10

-- Builds a level-specific reward pickup definition; these are looked up by
-- classname before falling back to `pickups.defaults`.
local function rewardPickup(name, classname, model, quantity)
  return {
      name = name,
      classname = classname,
      model = model,
      quantity = quantity,
      type = pickups.type.REWARD,
  }
end

-- Custom pickups for this level: the cake has quantity 10, the box -3.
local PICKUPS = {
    cake = rewardPickup('Cake', 'cake', 'models/hr_cake.md3', 10),
    box = rewardPickup('Box', 'box', 'models/fut_obj_cube_01.md3', -3),
}

-- Pickup classnames indexed by numeric id.
local ITEMS = {'apple_reward', 'box', 'cake'}

-- Inverse of ITEMS: classname -> numeric id.
local ITEMS_ID = {}
for id, classname in ipairs(ITEMS) do
  ITEMS_ID[classname] = id
end

-- Each room is a MAZE_SIZE x MAZE_SIZE grid of cells GRID_SIZE units apart;
-- spawned items and the player use ITEMS_OFFSET_Z as their z coordinate.
local MAZE_SIZE = 5
local GRID_SIZE = 120
local ITEMS_OFFSET_Z = 24

-- World-space x/y of grid cell (1, 1) for room 1 and room 2 respectively.
local ROOMS_BOUNDARY = {
    {minX = -208, minY = -240},
    {minX = 520, minY = -240},
}

-- Floor texture whose pixels are swapped between dark and bright.
local FLOOR_TEXTURE = 'textures/map/script_highlight'

local factory = {}

--[[ Creates an API for exploit deferred effects tasks.
The goal of this level is to learn the consequence of actions in non-overlapping
conditions.
The map of the level can be configured by kwargs.configs. If multiple configs
are given, the level will randomly choose one of them in each episode. The map
finishes when the player has picked up 10 objects or the time limit is reached.
Each episode contains exactly 1 map.

rooms_exploit_deferred_effects_train.lua contains the 3 configs for training:

1.  One room, with 10 apples and 1 box. The box gives a negative reward (-3)
    but makes the floor bright.
2.  Two rooms, there are 10 apples in the first room, and 10 cakes in the second
    one. The player spawns in the first room, the floor is dark, and there is no
    way to open the door between the first and second room.
3.  Same as scenario 2 except that the floor is bright and the door is open at
    first.

In all scenarios, an apple is worth 1 point and a cake is worth 10 points.

rooms_exploit_deferred_effects_test.lua contains the config for testing:

Two rooms, there are 10 apples and a box in the first room, and 10 cakes in the
second room. The player spawns in the first room, the floor is dark, and the
door is closed.

Ideally we hope the agent could learn the rules:

1.  Picking up the box turns the floor bright.
2.  Bright floor leads to the opened door, which leads to the cakes with higher
    points.

Keyword arguments:

*   `episodeLengthSeconds` (number) Episode length in seconds.
*   `configs` (table) Arguments used to decide the configuration of the rooms
    and pickups.
]]
function factory.createLevelApi(kwargs)
  assert(kwargs.episodeLengthSeconds)
  assert(kwargs.configs)

  -- 8x8x4 byte images substituted for FLOOR_TEXTURE: all 0 (dark) or all 70
  -- (bright).
  local darkFloorTexture = tensor.ByteTensor(8, 8, 4):fill(0)
  local brightFloorTexture = tensor.ByteTensor(8, 8, 4):fill(70)

  -- Delay (seconds) between the final pickup and the end of the episode.
  local FINISH_DELAY_SECONDS = 0.2

  local api = {}

  -- Returns every cell of a `height` x `width` grid as a {row, column} pair,
  -- shuffled with random:shuffleInPlace(positions, count). Callers only read
  -- the first `count` entries. Raises an error if the grid has fewer than
  -- `count` cells.
  local function shuffledPositions(height, width, count)
    if height * width < count then
      error('There are not enough positions to pick from.')
    end
    local positions = {}
    for i = 1, height do
      for j = 1, width do
        positions[#positions + 1] = {i, j}
      end
    end
    return random:shuffleInPlace(positions, count)
  end

  -- Seeds both random generators and clears per-episode timing state.
  function api:start(episode, seed)
    random:seed(seed)
    randomMap:seed(seed)
    api._timeOut = nil
    -- Reset the clock so a pickup before the first hasEpisodeFinished call
    -- neither does arithmetic on nil nor inherits the previous episode's time.
    api._time = 0
  end

  -- Picks a random config, lays out the item positions for this episode and
  -- returns the name of the map to load.
  function api:nextMap()
    local config = random:choice(kwargs.configs)
    -- A single flag drives both the layout and the map choice so they can
    -- never disagree (e.g. when a config omits roomCount).
    local twoRooms = config.roomCount ~= 1

    -- Room 1 slots: 1..OBJECT_NUM apples, OBJECT_NUM + 1 the box,
    -- OBJECT_NUM + 2 the player spawn. Room 2 slots: 1..OBJECT_NUM cakes.
    local room1 = shuffledPositions(MAZE_SIZE, MAZE_SIZE, OBJECT_NUM + 2)
    local room2
    if twoRooms then
      room2 = shuffledPositions(MAZE_SIZE, MAZE_SIZE, OBJECT_NUM)
    end

    api._rooms = {room1, room2}
    api._doorOpened = config.doorOpened
    api._hasBox = config.hasBox
    api._pickupCount = 0

    if twoRooms then
      return 'rooms_exploit_deferred_effects_two_rooms'
    end
    return 'rooms_exploit_deferred_effects_one_room'
  end

  -- Returns the level-specific pickup definition for `classname`, falling
  -- back to the shared defaults.
  function api:createPickup(classname)
    return PICKUPS[classname] or pickups.defaults[classname]
  end

  -- Handles a pickup. `entityId` indexes ITEMS (the ids assigned in
  -- updateSpawnVars). Picking up the box brightens the floor; the episode is
  -- scheduled to end shortly after OBJECT_NUM pickups.
  function api:pickup(entityId)
    local object = ITEMS[entityId]
    assert(object, 'Unknown pickup id: ' .. tostring(entityId))
    if object == 'box' then
      game:updateTexture(FLOOR_TEXTURE, brightFloorTexture)
    end
    api._pickupCount = api._pickupCount + 1
    if api._pickupCount >= OBJECT_NUM then
      api._timeOut = api._time + FINISH_DELAY_SECONDS
    end
  end

  -- Returns the "x y z" origin string of slot `inx` in room `roomId`
  -- (1 or 2), using this episode's shuffled positions.
  function api:_getOrigin(roomId, inx)
    local pos = api._rooms[roomId][inx]
    local x = (pos[1] - 1) * GRID_SIZE + ROOMS_BOUNDARY[roomId].minX
    local y = (pos[2] - 1) * GRID_SIZE + ROOMS_BOUNDARY[roomId].minY
    return x .. ' ' .. y .. ' ' .. ITEMS_OFFSET_Z
  end

  -- Places the player and items on the shuffled grid. Apples and cakes take
  -- their slot from spawnVars.script_id. Returning nil for the box when the
  -- config has none presumably suppresses that entity; wait = '-1'
  -- presumably disables respawning (engine semantics, not shown here).
  function api:updateSpawnVars(spawnVars)
    if spawnVars.classname == 'info_player_start' then
      spawnVars.origin = api:_getOrigin(1, OBJECT_NUM + 2)
    elseif spawnVars.classname == 'apple_reward' then
      spawnVars.id = tostring(ITEMS_ID[spawnVars.classname])
      local inx = tonumber(spawnVars.script_id)
      spawnVars.origin = api:_getOrigin(1, inx)
      spawnVars.wait = '-1'
    elseif spawnVars.classname == 'box' then
      if api._hasBox then
        spawnVars.id = tostring(ITEMS_ID[spawnVars.classname])
        spawnVars.origin = api:_getOrigin(1, OBJECT_NUM + 1)
        spawnVars.wait = '-1'
      else
        return nil
      end
    elseif spawnVars.classname == 'cake' then
      spawnVars.id = tostring(ITEMS_ID[spawnVars.classname])
      local inx = tonumber(spawnVars.script_id)
      spawnVars.origin = api:_getOrigin(2, inx)
      spawnVars.wait = '-1'
    elseif spawnVars.classname == 'func_door' then
      -- NOTE(review): spawnflags '1' presumably starts the door open.
      if api._doorOpened then
        spawnVars.spawnflags = '1'
      end
    end
    return spawnVars
  end

  -- Supplies the initial floor image: bright when the config opens the door,
  -- dark otherwise. Other textures are left to the engine (returns nil).
  function api:loadTexture(textureName)
    if textureName == FLOOR_TEXTURE then
      if api._doorOpened then
        return brightFloorTexture
      else
        return darkFloorTexture
      end
    end
  end

  -- Recolours the cake texture: in channel 1 every value below 128 becomes
  -- 200, in channels 2 and 3 every value below 128 becomes 0 (presumably
  -- R, G, B). Returns true only when the texture was modified.
  function api:modifyTexture(textureName, tensorData)
    if textureName == 'textures/model/hr_cake_d' then
      local function overrideDark(amount)
        return function(val)
          return val < 128 and amount or val
        end
      end
      tensorData:select(3, 1):apply(overrideDark(200))
      tensorData:select(3, 2):apply(overrideDark(0))
      tensorData:select(3, 3):apply(overrideDark(0))
      return true
    end
    return false
  end

  -- Records the current time (used by pickup) and reports whether the
  -- post-pickup deadline has passed. The overall episode length is enforced
  -- by the timeout decorator below.
  function api:hasEpisodeFinished(timeSeconds)
    api._time = timeSeconds
    return api._timeOut ~= nil and timeSeconds > api._timeOut
  end

  custom_observations.decorate(api)
  timeout.decorate(api, kwargs.episodeLengthSeconds)
  return api
end

return factory
